import torch

from load_local_refinement import get_gnn_inputs_local_refinement
from losses import gnn_compute_acc_ari_nmi_multiclass


def run_multi_refinement(
    gnn_model,
    W_np,
    init_labels,
    true_labels,
    args,
    device,
    num_iters: int = 1
):
    """Iteratively refine a labelling with a second-stage GNN.

    Each round does four things:

    1. Builds GNN inputs from the current labels via
       ``get_gnn_inputs_local_refinement``.
    2. Runs ``gnn_model`` on those inputs cast to float32.
    3. Scores the prediction against ``true_labels`` with
       ``gnn_compute_acc_ari_nmi_multiclass``.
    4. Feeds the returned best-matching labels into the next round.

    Args:
        gnn_model: Callable invoked as ``gnn_model(WW, x)``.
        W_np: Graph input forwarded unchanged to
            ``get_gnn_inputs_local_refinement`` (presumably a NumPy
            array, judging by the name -- confirm against callers).
        init_labels: Labels used to build the inputs of the first round.
        true_labels: Ground-truth labels, used only for scoring.
        args: Config object; ``args.J_second`` and ``args.n_classes``
            are read on every round.
        device: Target passed to ``.to()`` for both input tensors.
        num_iters: Number of refinement rounds. ``0`` returns an empty
            list without touching ``args`` or the model.

    Returns:
        list[dict]: One record per round, with the keys ``"iter"``
        (1-based round number), ``"pred_label"`` (best-matching labels),
        ``"pred"`` (raw model output), ``"acc"``, ``"ari"`` and
        ``"nmi"``.

    NOTE(review): the forward pass is not wrapped in ``torch.no_grad()``.
    If the model has trainable parameters, each stored ``"pred"`` keeps
    its autograd graph alive. TODO confirm whether gradients are needed
    here.
    """
    labels = init_labels
    records = []

    for round_idx in range(1, num_iters + 1):
        # Build the graph operator and node features from the current labelling.
        # args is read inside the loop (not hoisted) to mirror per-round access.
        ops, feats = get_gnn_inputs_local_refinement(
            W_np, args.J_second, labels, args.n_classes
        )
        ops = ops.to(device)
        feats = feats.to(device)

        # Forward pass; the model is fed float32 tensors.
        logits = gnn_model(ops.type(torch.float32), feats.type(torch.float32))

        # Score against the ground truth.
        acc, best_pred, ari, nmi = gnn_compute_acc_ari_nmi_multiclass(
            logits, true_labels, args.n_classes
        )

        records.append(
            dict(
                iter=round_idx,
                pred_label=best_pred,
                pred=logits,
                acc=acc,
                ari=ari,
                nmi=nmi,
            )
        )

        # The best-matching prediction seeds the next round's inputs.
        labels = best_pred

    return records


def run_refinement_chain(
    gnn_second_period,
    W_np,
    init_labels,
    true_labels,
    args,
    device,
    total_iters: int = 5,
    verbose: bool = True,
):
    """Run ``total_iters`` refinement rounds and summarise the outcome.

    This is a thin wrapper around ``run_multi_refinement``. It can print
    the metrics of each round and exposes the first and last round records
    for convenience.

    Args:
        gnn_second_period: Model forwarded to ``run_multi_refinement`` as
            ``gnn_model``.
        W_np, init_labels, true_labels, args, device: Forwarded unchanged
            to ``run_multi_refinement``.
        total_iters: Number of refinement rounds to run.
        verbose: When true, print one ``[Iter k] acc=..., ari=..., nmi=...``
            line per round, with metrics formatted to 4 decimals.

    Returns:
        dict: Contains the following keys:

        - ``"hist"``: the full per-round list.
        - ``"first_iter"``: the first round's record, or ``None`` when the
          history is empty (e.g. ``total_iters == 0``).
        - ``"final"``: the last round's record, or ``None`` under the same
          condition.
    """
    records = run_multi_refinement(
        gnn_second_period,
        W_np,
        init_labels,
        true_labels,
        args,
        device,
        num_iters=total_iters,
    )

    # Report the metrics of every round, numbered from 1.
    if verbose:
        for idx, rec in enumerate(records):
            print(
                f"[Iter {idx + 1}] acc={rec['acc']:.4f}, "
                f"ari={rec['ari']:.4f}, nmi={rec['nmi']:.4f}"
            )

    first, last = (records[0], records[-1]) if records else (None, None)
    return {"hist": records, "first_iter": first, "final": last}